# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
import collections
from hysop.tools.htypes import to_tuple, check_instance, first_not_None
from hysop.tools.numpywrappers import npw
from hysop.tools.io_utils import IO
from hysop.core.graph.graph import op_apply
from hysop.core.graph.node_requirements import OperatorRequirements
from hysop.core.graph.computational_graph import ComputationalGraphOperator
from hysop.constants import TranspositionState, MemoryOrdering, Backend
from hysop.fields.continuous_field import ScalarField, TensorField
from hysop.parameters.scalar_parameter import ScalarParameter
from hysop.parameters.tensor_parameter import TensorParameter
from hysop.backend.host.host_operator import HostOperatorBase
from hysop.topology.topology_descriptor import TopologyDescriptor
class PlottingOperator(HostOperatorBase):
    """
    Base operator for plotting.

    The operator owns a matplotlib figure and an array of axes. Every
    ``update_frequency`` iterations it calls :meth:`update` (implemented by
    subclasses) and redraws the figure. Every ``save_frequency`` iterations
    it writes the figure to ``{dump_dir}/{name}_{iteration:04d}.png``.
    Drawing and saving only happen on the MPI rank equal to ``visu_rank``.
    Pressing ``q`` or closing the window stops further drawing.
    """

    @classmethod
    def supports_mpi(cls):
        """Return True: the operator may run on several MPI ranks."""
        return True

    @classmethod
    def supported_backends(cls):
        """Return every backend (``Backend.all``)."""
        return Backend.all

    def __new__(
        cls,
        name=None,
        dump_dir=None,
        update_frequency=1,
        save_frequency=100,
        axes_shape=(1,),
        figsize=(30, 18),
        visu_rank=0,
        fig=None,
        axes=None,
        force_backend=None,
        **kwds,
    ):
        # Plotting arguments are consumed here; only the remaining keywords
        # are forwarded to the base class allocator.
        return super().__new__(cls, **kwds)

    def __init__(
        self,
        name=None,
        dump_dir=None,
        update_frequency=1,
        save_frequency=100,
        axes_shape=(1,),
        figsize=(30, 18),
        visu_rank=0,
        fig=None,
        axes=None,
        force_backend=None,
        **kwds,
    ):
        """
        Parameters
        ----------
        name : str
            Operator name, also used as the prefix of saved image files.
        dump_dir : str, optional
            Directory of saved images, defaults to ``IO.default_path()``.
        update_frequency : int
            Frequency (forwarded to ``io_params.clone``) at which
            :meth:`update` is called and the figure redrawn.
        save_frequency : int
            Frequency (forwarded to ``io_params.clone``) at which the
            figure is saved.
        axes_shape : tuple of int or None
            Shape of the subplot grid created when no figure is given.
        figsize : tuple
            Size of the created figure.
        visu_rank : int
            MPI rank that draws and saves the figure (also the I/O leader).
        fig, axes : optional
            Existing matplotlib figure and axes; must be given together.
        force_backend : Backend, optional
            Backend used to rebuild the input topology descriptors.
            ``Backend.OPENCL`` requires a ``cl_env`` keyword argument.
        **kwds
            Forwarded to the base operator.

        Raises
        ------
        RuntimeError
            If only one of ``fig`` and ``axes`` is given, or if no figure is
            given and ``axes_shape`` is None.
        """
        import matplotlib.pyplot as plt

        check_instance(name, str)
        check_instance(update_frequency, int, minval=0)
        check_instance(save_frequency, int, minval=0)
        check_instance(axes_shape, tuple, minsize=1, allow_none=True)
        super().__init__(name=name, io_params=True, **kwds)

        if (fig is None) ^ (axes is None):
            msg = "figure and axes should be specified at the same time."
            raise RuntimeError(msg)

        dump_dir = first_not_None(dump_dir, IO.default_path())
        # Doubled braces keep "{it:04d}" as a placeholder filled in save().
        imgpath = f"{dump_dir}/{name}_{{it:04d}}.png"

        if fig is None:
            if axes_shape is None:
                msg = "axes_shape must be specified when no figure is given."
                raise RuntimeError(msg)
            fig, axes = plt.subplots(*axes_shape, figsize=figsize)
        fig.canvas.mpl_connect("key_press_event", self.on_key_press)
        fig.canvas.mpl_connect("close_event", self.on_close)
        # plt.subplots returns either a single Axes or an array of them;
        # normalize to an array with the requested grid shape.
        axes = npw.asarray(axes)
        if axes_shape is not None:
            axes = axes.reshape(axes_shape)

        self.fig = fig
        self.axes = axes
        self.update_frequency = update_frequency
        self.save_frequency = save_frequency
        self.imgpath = imgpath
        self.should_draw = visu_rank == self.mpi_params.rank
        self.running = True
        self.first_draw = True
        self.plt = plt

        self.update_ioparams = self.io_params.clone(
            frequency=self.update_frequency,
            io_leader=visu_rank,
            visu_leader=visu_rank,
            with_last=True,
        )
        self.save_ioparams = self.io_params.clone(
            frequency=self.save_frequency,
            io_leader=visu_rank,
            visu_leader=visu_rank,
            with_last=True,
        )

        td_kwds = {}
        if force_backend is Backend.OPENCL:
            # NOTE(review): cl_env has already been forwarded to
            # super().__init__ through **kwds above; confirm this is intended.
            assert "cl_env" in kwds
            td_kwds["cl_env"] = kwds.pop("cl_env")
        self.force_backend = force_backend
        self.td_kwds = td_kwds

    def create_topology_descriptors(self):
        """
        Rebuild the topology descriptor of every input field.

        Recreating the descriptors allows a forced backend (such as an OpenCL
        mapped-memory backend). It also avoids allocating memory for a
        topology that is only used for I/O.
        """
        # Iterate over a snapshot since entries are reassigned in the loop.
        for field, topo_descriptor in tuple(self.input_fields.items()):
            self.input_fields[field] = TopologyDescriptor.build_descriptor(
                backend=self.force_backend,
                operator=self,
                field=field,
                handle=topo_descriptor,
                **self.td_kwds,
            )

    def get_field_requirements(self):
        """Require C-contiguous memory and default axes for every field."""
        requirements = super().get_field_requirements()
        for _is_input, ireq in requirements.iter_requirements():
            if ireq is None:
                continue
            field, _td, req = ireq
            req.memory_order = MemoryOrdering.C_CONTIGUOUS
            req.axes = (TranspositionState[field.dim].default_axes(),)
        return requirements

    def get_node_requirements(self):
        """Enforce unique transposition state and memory order only."""
        node_reqs = super().get_node_requirements()
        node_reqs.enforce_unique_transposition_state = True
        node_reqs.enforce_unique_topology_shape = False
        node_reqs.enforce_unique_memory_order = True
        node_reqs.enforce_unique_ghosts = False
        return node_reqs

    def draw(self):
        """Redraw the figure (drawing rank only, while the window is open)."""
        if not self.should_draw or not self.running:
            return
        self.fig.canvas.draw()
        self.fig.show()
        # Longer pause on the first draw, presumably to let the GUI window
        # appear; short pauses afterwards to process GUI events.
        if self.first_draw:
            self.plt.pause(1.0)
            self.first_draw = False
        else:
            self.plt.pause(0.01)

    @op_apply
    def apply(self, **kwds):
        """Update/draw, then save, each according to its own I/O frequency."""
        self._update(**kwds)
        self._save(**kwds)

    def _update(self, simulation, **kwds):
        """Call update() and draw() when update_ioparams says to dump."""
        if self.update_ioparams.should_dump(simulation=simulation):
            self.update(simulation=simulation, **kwds)
            self.draw()

    def _save(self, simulation, **kwds):
        """Call save() when save_ioparams says to dump."""
        if self.save_ioparams.should_dump(simulation=simulation):
            self.save(simulation=simulation, **kwds)

    @abstractmethod
    def update(self, **kwds):
        """Refresh the plotted data; implemented by subclasses."""
        pass

    def save(self, simulation, **kwds):
        """Write the figure to ``imgpath`` for the current iteration (drawing rank only)."""
        if not self.should_draw:
            return
        import os

        path = self.imgpath.format(it=simulation.current_iteration)
        # savefig does not create missing directories.
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        self.fig.savefig(path, dpi=self.fig.dpi, bbox_inches="tight")

    def on_close(self, event):
        """Stop drawing once the figure window is closed."""
        self.running = False

    def on_key_press(self, event):
        """Close the figure and stop drawing when ``q`` is pressed."""
        if event.key == "q":
            self.plt.close(self.fig)
            self.running = False
class FieldPlotter2D(PlottingOperator):
    """
    Base operator to plot 2D fields at runtime.

    Every scalar field is displayed with ``imshow`` on its own axis. The
    colour limits are reset to the current data range at each update.
    """

    def __new__(
        cls,
        name,
        fields,
        variables,
        fig_title=None,
        imshow_kwds=None,
        add_colorbars=True,
        symmetric_cbar=False,
        fig=None,
        axes=None,
        shape=None,
        **kwds,
    ):
        # Plotter-specific arguments are not forwarded to the base allocator.
        return super().__new__(cls, **kwds)

    def __init__(
        self,
        name,
        fields,
        variables,
        fig_title=None,
        imshow_kwds=None,
        add_colorbars=True,
        symmetric_cbar=False,
        fig=None,
        axes=None,
        shape=None,
        **kwds,
    ):
        """
        Parameters
        ----------
        name : str
            Operator name.
        fields : dict, tuple or list
            Fields to plot, given as one of:

            * ``{matplotlib.axes.Axes: ScalarField}`` when ``fig`` and
              ``axes`` are both provided;
            * ``{(row, col): ScalarField}``, the grid shape being deduced
              from the largest indices unless ``shape`` is given;
            * a sequence of ScalarField/TensorField, plotted on a single
              row with one axis per scalar component.
        variables : dict or object
            Value associated with each field in ``input_fields`` (used as the
            topology descriptor handle); a non-dict value is used for every
            field.
        fig_title : callable, optional
            ``fig_title(simulation) -> str``; defaults to time and iteration.
        imshow_kwds : dict, optional
            Extra keyword arguments for ``Axes.imshow``. The dict is copied,
            and missing entries default to ``interpolation="bilinear"``,
            ``origin="lower"`` and ``cmap="bwr"``.
        add_colorbars : bool
            Attach a vertical colorbar to each axis.
        symmetric_cbar : bool
            Use colour limits ``[-m, m]`` with ``m = max(|min|, |max|)``.
        fig, axes : optional
            Existing figure and axes, forwarded to PlottingOperator.
        shape : tuple of int, optional
            Overrides the deduced axes grid shape.
        **kwds
            Forwarded to PlottingOperator.

        Raises
        ------
        TypeError
            If ``fields`` is not a dict, tuple or list.
        """
        # Needed to validate custom-axes keys; binds the local name matplotlib.
        import matplotlib.axes

        # Copy so that setdefault does not mutate the caller's dictionary.
        imshow_kwds = dict(first_not_None(imshow_kwds, {}))
        imshow_kwds.setdefault("interpolation", "bilinear")
        imshow_kwds.setdefault("origin", "lower")
        imshow_kwds.setdefault("cmap", "bwr")

        def default_figtitle(simulation):
            return f"Fields at t={simulation.time}, iteration={simulation.current_iteration}"

        fig_title = first_not_None(fig_title, default_figtitle)
        assert callable(fig_title), "fig_title has to be a function."

        if not isinstance(variables, dict):
            variables = collections.defaultdict(lambda v=variables: v)

        if (fig is not None) and (axes is not None):
            check_instance(fields, dict, keys=matplotlib.axes.Axes, values=ScalarField)
            input_fields = {p: variables[p] for p in fields.values()}
            axes_shape = None
        elif isinstance(fields, dict):
            check_instance(fields, dict, keys=tuple, values=ScalarField)
            input_fields = {p: variables[p] for p in fields.values()}
            indices = npw.asarray(list(fields.keys()), dtype=npw.int32)
            assert indices.shape[-1] == 2, indices.shape
            assert (indices >= 0).all(), indices
            # Smallest grid containing every (row, col) position.
            axes_shape = tuple(1 + indices.max(axis=0))
            axes_shape = first_not_None(shape, axes_shape)
        elif isinstance(fields, (tuple, list)):
            check_instance(fields, (tuple, list), values=(TensorField, ScalarField))
            input_fields = {p: variables[p] for p in fields}
            naxes = sum(f.nb_components for f in fields)
            axes_shape = first_not_None(shape, (1, naxes))
            # Flatten fields into scalar components mapped to flat axis indices.
            fields = dict(zip(range(naxes), sum(map(tuple, fields), ())))
            check_instance(fields, dict, keys=int, values=ScalarField)
        else:
            raise TypeError(fields)

        assert all(field.dim == 2 for field in input_fields.keys()), "Fields are not 2D."

        super().__init__(
            name=name,
            input_fields=input_fields,
            axes_shape=axes_shape,
            axes=axes,
            fig=fig,
            **kwds,
        )

        # FigureCanvasBase.set_window_title was removed from matplotlib; the
        # title lives on the figure manager, which is None for figures that
        # pyplot does not manage.
        manager = getattr(self.fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title("HySoP Field Plotter")

        self._plt_cfields = fields
        self._plt_dfields = None
        self._imshow_handles = None
        self._imshow_kwds = imshow_kwds
        self._add_colorbars = add_colorbars
        self._symmetric_cbar = symmetric_cbar
        self._fig_title = fig_title

    def discretize(self):
        """
        Discretize the operator and prepare one image per scalar field.

        On the drawing rank each discrete scalar field gets an ``imshow``
        artist, plus an optional colorbar. Other ranks register the same
        discrete fields with a ``None`` handle. This lets :meth:`update`
        iterate over them and call ``collect_data`` on every rank, which is
        presumably a collective gather to the I/O leader (to be confirmed).
        """
        if self.discretized:
            return
        super().discretize()
        self._plt_dfields = {}
        self._imshow_handles = {}
        for axis_key, input_field in self._plt_cfields.items():
            discrete_field = self.get_input_discrete_field(input_field)
            for scalar_dfield in discrete_field:
                if not self.should_draw:
                    self._imshow_handles[scalar_dfield] = None
                    continue
                # Keys are flat indices, grid positions or Axes objects.
                if isinstance(axis_key, int):
                    axis = self.axes.ravel()[axis_key]
                elif isinstance(axis_key, tuple):
                    axis = self.axes[axis_key]
                else:
                    axis = axis_key
                self._plt_dfields[axis] = scalar_dfield
                axis.set_title(scalar_dfield.name)
                axis.set_xlabel("x")
                axis.set_ylabel("y")
                # Placeholder image; real data is set in update().
                data = npw.zeros(
                    shape=discrete_field.mesh.grid_resolution,
                    dtype=discrete_field.dtype,
                )
                box = scalar_dfield.domain
                extent = (box.origin[1], box.end[1], box.origin[0], box.end[0])
                img = axis.imshow(data, extent=extent, **self._imshow_kwds)
                self._imshow_handles[scalar_dfield] = img
                if self._add_colorbars:
                    from mpl_toolkits.axes_grid1 import make_axes_locatable

                    divider = make_axes_locatable(axis)
                    cax = divider.append_axes("right", size="5%", pad=0.05)
                    self.fig.colorbar(img, cax=cax, orientation="vertical")

    def update(self, simulation, **kwds):
        """Refresh the figure title and every image with current field data."""
        self.fig.suptitle(self._fig_title(simulation))
        leader = self.update_ioparams.io_leader
        for dfield, handle in self._imshow_handles.items():
            data = dfield.collect_data(leader=leader)
            self._update_imshow_handle(handle, data)

    def _update_imshow_handle(self, handle, data):
        """Set the image data and rescale its colour limits (no-op for None handles)."""
        if handle is None:
            return
        handle.set_data(data)
        dmin, dmax = data.min(), data.max()
        if self._symmetric_cbar:
            dinf = max(abs(dmin), abs(dmax))
            handle.set_clim(-dinf, +dinf)
        else:
            handle.set_clim(+dmin, +dmax)
class ParameterPlotter(PlottingOperator):
    """
    Base operator to plot parameters during runtime.

    Every scalar parameter is drawn as a line of its value against simulation
    time. Non-scalar tensor parameters get one line per component.
    """

    def __new__(
        cls, name, parameters, alloc_size=128, fig=None, axes=None, shape=None, **kwds
    ):
        # Mirror FieldPlotter2D: plotter-specific arguments are not forwarded
        # to the base class allocator.
        return super().__new__(cls, **kwds)

    def __init__(
        self, name, parameters, alloc_size=128, fig=None, axes=None, shape=None, **kwds
    ):
        """
        Parameters
        ----------
        name : str
            Operator name.
        parameters : TensorParameter, list, tuple or dict
            Parameters to plot. Without custom axes this may be one of:

            * a single TensorParameter;
            * a sequence, where entry ``i`` is plotted at position ``i``;
            * a dict mapping a position (int or tuple) to a TensorParameter,
              a sequence of parameters, or a ``{label: parameter}`` dict.

            Positions are left-padded with zeros to ``(row, col)``. With
            custom ``fig`` and ``axes`` it must be
            ``{matplotlib.axes.Axes: {label: ScalarParameter}}``. Non-scalar
            tensor parameters are split into views labelled
            ``f"{label}_{index}"``.
        alloc_size : int
            Initial capacity of the time/value buffers, which double when full.
        fig, axes : optional
            Existing figure and axes, forwarded to PlottingOperator.
        shape : optional
            Accepted for interface compatibility; unused by this class.
        **kwds
            Forwarded to PlottingOperator.

        Raises
        ------
        TypeError
            If ``parameters`` (or one of its entries) has an unsupported type.
        """
        # Set of parameter objects read by this operator.
        input_params = set()
        if (fig is not None) and (axes is not None):
            # Binds the local name matplotlib and ensures matplotlib.axes is loaded.
            import matplotlib.axes

            custom_axes = True
            axes_shape = None
            check_instance(parameters, dict, keys=matplotlib.axes.Axes, values=dict)
            for params in parameters.values():
                check_instance(params, dict, keys=str, values=ScalarParameter)
                input_params.update(params.values())
        else:
            custom_axes = False
            if isinstance(parameters, TensorParameter):
                _parameters = {0: parameters}
            elif isinstance(parameters, (list, tuple)):
                _parameters = dict(enumerate(parameters))
            elif isinstance(parameters, dict):
                _parameters = parameters.copy()
            else:
                raise TypeError(type(parameters))
            check_instance(
                _parameters,
                dict,
                keys=(int, tuple, list),
                values=(TensorParameter, list, tuple, dict),
            )
            parameters = {}
            axes_shape = (1, 1)
            for pos, params in _parameters.items():
                # Left-pad the position to (row, col) and grow the grid so
                # that it contains this position.
                pos = to_tuple(pos)
                pos = (2 - len(pos)) * (0,) + pos
                check_instance(pos, tuple, values=int)
                axes_shape = tuple(max(p0, p1 + 1) for (p0, p1) in zip(axes_shape, pos))
                if isinstance(params, TensorParameter):
                    params = {params.name: params}
                elif isinstance(params, (list, tuple)):
                    params = {p.name: p for p in params}
                elif not isinstance(params, dict):
                    raise TypeError(type(params))
                check_instance(params, dict, keys=str, values=TensorParameter)
                input_params.update(params.values())
                scalar_params = {}
                for pname, p in params.items():
                    if isinstance(p, ScalarParameter):
                        scalar_params[pname] = p
                    else:
                        # One scalar view per tensor component.
                        for idx in npw.ndindex(*p.shape):
                            scalar_params[f"{pname}_{idx}"] = p.view(idx)
                parameters[pos] = scalar_params

        super().__init__(
            name=name,
            input_params=input_params,
            axes_shape=axes_shape,
            axes=axes,
            fig=fig,
            **kwds,
        )
        self.custom_axes = custom_axes

        # Preallocated sample buffers and one (initially empty) line per parameter.
        times = npw.empty(shape=(alloc_size,), dtype=npw.float32)
        data = {}
        lines = {}
        for pos, params in parameters.items():
            axis = self.get_axes(pos)
            params_data = {}
            params_lines = {}
            for pname, p in params.items():
                params_data[p] = npw.empty(shape=(alloc_size,), dtype=p.dtype)
                params_lines[p] = axis.plot([], [], label=pname)[0]
            data[pos] = params_data
            lines[pos] = params_lines

        # FigureCanvasBase.set_window_title was removed from matplotlib; the
        # title lives on the figure manager, which is None for figures that
        # pyplot does not manage.
        manager = getattr(self.fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title("HySoP Parameter Plotter")

        self.parameters = parameters
        self.times = times
        self.data = data
        self.lines = lines
        self.alloc_size = alloc_size
        self.counter = 0

    def get_axes(self, pos):
        """Return the axis for ``pos`` (``pos`` itself when axes are custom)."""
        if self.custom_axes:
            return pos
        return self.axes[pos]

    def __getitem__(self, i):
        """Return axis ``i`` (flat index into the grid unless axes are custom)."""
        if self.custom_axes:
            return self.axes[i]
        return self.axes.flatten()[i]

    def _grow_buffers(self):
        """Double the capacity of the time and value buffers, keeping their content."""

        def grown(buf):
            # max(1, ...) ensures progress even when alloc_size was 0.
            new_buf = npw.empty(shape=(max(1, 2 * buf.size),), dtype=buf.dtype)
            new_buf[: buf.size] = buf
            return new_buf

        self.times = grown(self.times)
        for params in self.data.values():
            for p in tuple(params):
                params[p] = grown(params[p])

    def update(self, simulation, **kwds):
        """
        Record the current time and parameter values, then refresh the lines.

        Each line shows every sample recorded so far, including the one just
        appended.
        """
        if self.counter >= self.times.size:
            self._grow_buffers()
        i = self.counter
        times, data, lines = self.times, self.data, self.lines
        times[i] = simulation.t()
        for pos, params in self.parameters.items():
            for p in params.values():
                values = data[pos][p]
                values[i] = p()
                lines[pos][p].set_xdata(times[: i + 1])
                lines[pos][p].set_ydata(values[: i + 1])
        self.counter = i + 1